{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pretty printing has been turned OFF\n",
      "Pretty printing has been turned ON\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pprint\n",
    "from Environment import GridworldEnv\n",
    "from pprint import PrettyPrinter\n",
    "\n",
    "# `%pprint` is a toggle: each run of this cell flips IPython pretty printing\n",
    "# (the recorded output above shows it being switched both OFF and ON).\n",
    "%pprint\n",
    "# Pretty printer with 4-space indentation; NOTE(review): `pp` is not referenced\n",
    "# by any other cell visible in this notebook - confirm it is still needed.\n",
    "pp = PrettyPrinter(indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_action_value(state, V, discount_factor=1.0, env=None):\n",
    "    \"\"\"\n",
    "    Compute the expected value of every action available in `state`.\n",
    "\n",
    "    Performs a one-step lookahead: for each action, sums\n",
    "    prob * (reward + discount_factor * V[next_state]) over the transitions\n",
    "    listed in env.P[state][action].\n",
    "\n",
    "    Args:\n",
    "        state: Index of the state to evaluate; used as a key into env.P.\n",
    "        V: Current state-value estimates, indexable by next-state index.\n",
    "        discount_factor: Discount applied to the value of the next state.\n",
    "        env: Environment exposing `nA` (number of actions) and `P`, where\n",
    "            P[state][action] yields (prob, next_state, reward, done) tuples.\n",
    "            When omitted, the notebook-level global `env` is used so that\n",
    "            existing three-argument calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        NumPy array of length env.nA with the expected value of each action.\n",
    "\n",
    "    Raises:\n",
    "        NameError: If `env` is not passed and no global `env` is defined.\n",
    "    \"\"\"\n",
    "    if env is None:\n",
    "        # Backward compatibility with callers that relied on the global.\n",
    "        try:\n",
    "            env = globals()['env']\n",
    "        except KeyError:\n",
    "            raise NameError('no environment available: pass env=... explicitly') from None\n",
    "\n",
    "    A = np.zeros(env.nA)\n",
    "    for a in range(env.nA):\n",
    "        # The `done` flag of a transition is not needed for the lookahead.\n",
    "        for prob, next_state, reward, _done in env.P[state][a]:\n",
    "            A[a] += prob * (reward + discount_factor * V[next_state])\n",
    "    return A\n",
    "\n",
    "\n",
    "def value_iteration(env, theta=0.1, discount_factor=1.0, max_iterations=50):\n",
    "    \"\"\"\n",
    "    Value Iteration algorithm.\n",
    "\n",
    "    Repeatedly sweeps over all states, overwriting each state's value in\n",
    "    place with its best one-step lookahead value, until the largest change\n",
    "    seen in a sweep falls below `theta` or `max_iterations` sweeps have run.\n",
    "    A deterministic greedy policy is then read off the final values.\n",
    "\n",
    "    Args:\n",
    "        env: Environment exposing `nS` (number of states), `nA` (number of\n",
    "            actions) and `P`, where P[state][action] yields\n",
    "            (prob, next_state, reward, done) tuples.\n",
    "        theta: Stopping threshold on the largest value change in one sweep.\n",
    "        discount_factor: Discount applied to next-state values.\n",
    "        max_iterations: Upper bound on the number of sweeps.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (policy, V). `policy` is an array of shape [env.nS, env.nA]\n",
    "        with a single 1.0 per row marking the greedy action; `V` is an\n",
    "        array of length env.nS with the state values.\n",
    "    \"\"\"\n",
    "    # Start from an all-zero value function.\n",
    "    V = np.zeros(env.nS)\n",
    "\n",
    "    for _ in range(max_iterations):\n",
    "        delta = 0  # largest value change observed during this sweep\n",
    "\n",
    "        for s in range(env.nS):\n",
    "            # Pass env and discount_factor explicitly so the lookahead uses\n",
    "            # this call's arguments rather than notebook-level state/defaults.\n",
    "            A = calc_action_value(s, V, discount_factor, env)\n",
    "            best_action_value = np.max(A)\n",
    "\n",
    "            delta = max(delta, np.abs(best_action_value - V[s]))\n",
    "\n",
    "            # In-place update: later states in the same sweep see this value.\n",
    "            V[s] = best_action_value\n",
    "\n",
    "        if delta < theta:\n",
    "            break\n",
    "\n",
    "    # Derive the deterministic policy that is greedy w.r.t. the final values.\n",
    "    policy = np.zeros([env.nS, env.nA])\n",
    "\n",
    "    for s in range(env.nS):\n",
    "        A = calc_action_value(s, V, discount_factor, env)\n",
    "        # np.argmax breaks ties in favour of the lowest action index.\n",
    "        best_action = np.argmax(A)\n",
    "        policy[s, best_action] = 1.0\n",
    "\n",
    "    return policy, V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\n",
      "[[1 1 2 3]\n",
      " [2 0 2 0]\n",
      " [1 1 2 0]\n",
      " [0 1 1 0]]\n",
      "\n",
      "Reshaped Grid Value Function:\n",
      "[[ 38.  40.  42.  41.]\n",
      " [ 40. -50.  44. -50.]\n",
      " [ 42.  44.  46. -50.]\n",
      " [-50.  46.  48.  50.]]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Build the gridworld and solve it with value iteration.\n",
    "env = GridworldEnv()\n",
    "policy, v = value_iteration(env)\n",
    "\n",
    "# Show the action selected by the policy for each state, laid out as env.shape.\n",
    "print(\n",
    "    \"Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\",\n",
    "    policy.argmax(axis=1).reshape(env.shape),\n",
    "    \"\",\n",
    "    sep=\"\\n\",\n",
    ")\n",
    "\n",
    "# Show the computed state values in the same layout.\n",
    "print(\"Reshaped Grid Value Function:\", np.reshape(v, env.shape), \"\", sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
